"""
Phase 9: Robustness Matrix
Tests Z_1 coefficient stability across specifications and subsamples
for real bond yields and inflation.

Outputs:
  - Table 20: Robustness matrix for real_bond_10y (phase9_table20_robustness_rates.md)
  - Table 21: Robustness matrix for inflation (phase9_table21_robustness_inflation.md)
"""

import sys
from pathlib import Path
import numpy as np
import pandas as pd

# ── paths ────────────────────────────────────────────────────────────────────
# NOTE(review): absolute, machine-specific project root (looks like a WSL mount
# of C:); the script only runs where this directory layout exists.
PROJECT = Path("/mnt/c/demographics_capital_flows")
# Put the multilateral project's src/ first on sys.path so its PanelGLS
# estimator can be imported and reused here.
sys.path.insert(0, str(PROJECT / "multilateral" / "src"))
from model import PanelGLS

# Input panel read by main(), and the directory the markdown tables are written to.
DATA_PATH = PROJECT / "monetary" / "data" / "processed" / "monetary_panel.csv"
TABLE_DIR = PROJECT / "monetary" / "output" / "tables"
# Side effect at import time: ensure the output directory exists before any write.
TABLE_DIR.mkdir(parents=True, exist_ok=True)

# ISO3 codes (38 entries) defining the "OECD" subsample in main().
OECD_38 = [
    "AUS", "AUT", "BEL", "CAN", "CHL", "COL", "CRI", "CZE", "DNK", "EST",
    "FIN", "FRA", "DEU", "GRC", "HUN", "ISL", "IRL", "ISR", "ITA", "JPN",
    "KOR", "LVA", "LTU", "LUX", "MEX", "NLD", "NZL", "NOR", "POL", "PRT",
    "SVK", "SVN", "ESP", "SWE", "CHE", "TUR", "GBR", "USA",
]


# ── helpers ──────────────────────────────────────────────────────────────────
def run_model(panel, dep_var, rhs_vars, min_obs=50):
    """Run PanelGLS; return dict with coefs/se/pvals/r2/nobs or None.

    Rows with a missing value in any model column (dependent variable,
    regressors, ``iso3``, ``year``) are dropped before fitting.

    Parameters
    ----------
    panel : pandas.DataFrame
        Panel containing ``dep_var``, ``rhs_vars`` and the ``iso3`` / ``year``
        identifiers that are passed to ``PanelGLS.fit``.
    dep_var : str
        Dependent-variable column.
    rhs_vars : list of str
        Regressor columns. Duplicates are dropped (first occurrence kept).
    min_obs : int, default 50
        Minimum number of complete-case rows required to attempt a fit.

    Returns
    -------
    dict or None
        ``coefs`` / ``se`` / ``pvals`` (dicts keyed by regressor name; this
        assumes PanelGLS reports one estimate per column of X, in order) plus
        ``r2``, ``nobs`` and ``ncountries``. ``None`` when a required column
        is missing, ``dep_var`` is also listed as a regressor, fewer than
        ``min_obs`` complete rows remain, or the fit raises. The reason is
        printed in every case.
    """
    # Duplicate labels would make df[col] return a DataFrame rather than a
    # Series and add spurious columns to X, so de-duplicate (order-preserving).
    rhs_vars = list(dict.fromkeys(rhs_vars))
    if dep_var in rhs_vars:
        print(f"  SKIP {dep_var}: dependent variable also listed as a regressor")
        return None

    cols = list(dict.fromkeys([dep_var, *rhs_vars, "iso3", "year"]))
    missing = [c for c in cols if c not in panel.columns]
    if missing:
        print(f"  SKIP {dep_var}: missing columns {missing}")
        return None

    # Complete-case sample across every column the model uses.
    df = panel[cols].dropna()
    if len(df) < min_obs:
        print(f"  SKIP {dep_var}: only {len(df)} obs after dropna (need {min_obs})")
        return None

    # Best-effort per table cell: a failed fit is reported and rendered as "--"
    # instead of aborting the whole robustness matrix.
    try:
        gls = PanelGLS()
        gls.fit(df[dep_var].values, df[rhs_vars].values,
                df["iso3"].values, df["year"].values)
        return {
            "coefs": dict(zip(rhs_vars, gls.beta)),
            "se": dict(zip(rhs_vars, gls.se)),
            "pvals": dict(zip(rhs_vars, gls.pvalues)),
            "r2": gls.r_squared,
            "nobs": gls.n_obs,
            "ncountries": gls.n_countries,
        }
    except Exception as e:
        print(f"  ERROR {dep_var}: {e}")
        return None


def star(p):
    """Map a p-value to its significance marker.

    Returns ``"***"`` for p < 0.01, ``"**"`` for p < 0.05, ``"*"`` for
    p < 0.10 and ``""`` otherwise. NaN falls through to ``""`` because every
    comparison with NaN is False.
    """
    for cutoff, marker in ((0.01, "***"), (0.05, "**"), (0.10, "*")):
        if p < cutoff:
            return marker
    return ""


def fmt_cell(res, var):
    """Render one robustness-matrix cell as ``'coef*** (se)'``.

    Returns ``"--"`` when there is no result, when *var* was not estimated,
    or when its coefficient or standard error is NaN. Significance stars
    come from ``star`` applied to the p-value.
    """
    if res is None or var not in res["coefs"]:
        return "--"
    coef, se, pval = (res[key][var] for key in ("coefs", "se", "pvals"))
    if not (np.isnan(coef) or np.isnan(se)):
        return f"{coef:.3f}{star(pval)} ({se:.3f})"
    return "--"


# ── main ─────────────────────────────────────────────────────────────────────
def main():
    """Build, print and save the Phase 9 robustness matrices (Tables 20 and 21).

    Reads the monetary panel and fits nine specifications on seven subsamples,
    once with real_bond_10y and once with inflation as the dependent variable.
    Each markdown table is written to TABLE_DIR, and a per-specification count
    of 5%-significant cells is printed at the end.
    """
    print("=" * 70)
    print("PHASE 9: ROBUSTNESS MATRIX")
    print("=" * 70)

    panel = pd.read_csv(DATA_PATH)
    print(f"Panel: {len(panel):,} obs, {panel['iso3'].nunique()} countries")

    # Build the interaction before slicing so every subsample inherits it.
    _add_mi_interaction(panel)
    subsamples = _build_subsamples(panel)

    controls_rate = ["rgdp_growth", "inflation", "fiscal_bal_gdp", "kaopen", "nfa_gdp_lag"]
    # Inflation is the dependent variable of Table 21, so output_gap replaces it.
    controls_infl = ["rgdp_growth", "output_gap", "fiscal_bal_gdp", "kaopen", "nfa_gdp_lag"]

    rows_rate = _robustness_table(
        20, "real_bond_10y", "Real 10-Year Bond Yield", controls_rate,
        subsamples, "phase9_table20_robustness_rates.md",
    )
    rows_infl = _robustness_table(
        21, "inflation", "Inflation", controls_infl,
        subsamples, "phase9_table21_robustness_inflation.md",
    )

    print("\n" + "=" * 70)
    print("SUMMARY")
    print("=" * 70)
    _print_summary("\nTable 20 (real_bond_10y):", rows_rate)
    _print_summary("\nTable 21 (inflation):", rows_infl)

    print("\nPhase 9 complete.")


# ── main() helpers ───────────────────────────────────────────────────────────
# Footnotes shared by both tables, one bullet per specification row.
_SPEC_NOTES = (
    "- *Row 1: Baseline (Z_1,Z_2,Z_3 + controls). Reports Z_1.*",
    "- *Row 2: Replace Z with old_dep, youth_dep, working_age_share. "
    "Reports old_dep.*",
    "- *Row 3: 5-year lagged demographics. Reports Z_1_lag5.*",
    "- *Row 4: First differences. Reports dZ_1.*",
    "- *Row 5: Add gross_investment_gdp control. Reports Z_1.*",
    "- *Row 6: Add gdp_pc_ppp control. Reports Z_1.*",
    "- *Row 7: Exclude Japan and Germany. Reports Z_1.*",
    "- *Row 8: Replace Z_1 with oadr_plus20 (predetermined). "
    "Reports oadr_plus20.*",
    "- *Row 9: Add mi_index and Z_1*mi_index. Reports Z_1.*",
)


def _add_mi_interaction(df):
    """Add the ``Z_1_x_mi`` (Z_1 * mi_index) column to *df* in place, if buildable.

    The column is left absent when either input is missing, so the trilemma
    specification is skipped instead of the whole script raising KeyError.
    """
    if "Z_1_x_mi" not in df.columns and {"Z_1", "mi_index"}.issubset(df.columns):
        df["Z_1_x_mi"] = df["Z_1"] * df["mi_index"]


def _income_terciles(panel):
    """Return the (low, middle, high) income subsamples of *panel*.

    Uses an existing ``income_tercile`` column when present. Otherwise it
    adds one, splitting countries at the 33rd/67th percentiles of their
    median ``gdp_pc_ppp``. Countries with no GDP-per-capita data stay
    unclassified rather than falling through to "high". If neither column
    exists, all three subsamples are empty.
    """
    if "income_tercile" not in panel.columns:
        if "gdp_pc_ppp" not in panel.columns:
            print("WARNING: no income_tercile or gdp_pc_ppp column; "
                  "income subsamples are empty")
            empty = panel.iloc[0:0]
            return empty.copy(), empty.copy(), empty.copy()
        gdp_med = panel.groupby("iso3")["gdp_pc_ppp"].median()
        t1, t2 = gdp_med.quantile(0.33), gdp_med.quantile(0.67)

        def _label(x):
            if pd.isna(x):
                return np.nan  # NaN <= t would be False and wrongly yield "high"
            if x <= t1:
                return "low"
            return "middle" if x <= t2 else "high"

        panel["income_tercile"] = panel["iso3"].map(gdp_med.map(_label))
    return tuple(
        panel[panel["income_tercile"] == label].copy()
        for label in ("low", "middle", "high")
    )


def _build_subsamples(panel):
    """Print subsample sizes and return the ordered (label, DataFrame) columns."""
    low_inc, mid_inc, high_inc = _income_terciles(panel)
    oecd = panel[panel["iso3"].isin(OECD_38)].copy()
    pre_gfc = panel[panel["year"] <= 2007].copy()
    post_gfc = panel[panel["year"] >= 2008].copy()

    print(f"OECD:     {len(oecd):,} obs, {oecd['iso3'].nunique()} countries")
    print(f"Pre-GFC:  {len(pre_gfc):,} obs")
    print(f"Post-GFC: {len(post_gfc):,} obs")
    print(f"Low inc:  {len(low_inc):,} obs, {low_inc['iso3'].nunique()} countries")
    print(f"Mid inc:  {len(mid_inc):,} obs, {mid_inc['iso3'].nunique()} countries")
    print(f"High inc: {len(high_inc):,} obs, {high_inc['iso3'].nunique()} countries")

    return [
        ("Full", panel),
        ("OECD", oecd),
        ("Pre-GFC", pre_gfc),
        ("Post-GFC", post_gfc),
        ("Low inc", low_inc),
        ("Mid inc", mid_inc),
        ("High inc", high_inc),
    ]


def _specifications(controls):
    """Return the nine robustness specifications for one control set.

    Each entry is ``(label, gate_cols, rhs, report_var, exclude_iso3)``.
    A spec is skipped as "missing vars" when any *gate_cols* column is
    absent. *rhs* is the regressor list, *report_var* the coefficient shown
    in the table, and *exclude_iso3* the countries dropped before fitting.
    """
    controls = list(controls)
    z_vars = ["Z_1", "Z_2", "Z_3"]
    alt_demos = ["old_dep", "youth_dep", "working_age_share"]
    lag_vars = ["Z_1_lag5", "Z_2_lag5", "Z_3_lag5"]
    diff_vars = ["dZ_1", "dZ_2", "dZ_3"]
    trilemma = ["mi_index", "Z_1_x_mi"]
    return [
        ("1. Baseline", [], z_vars + controls, "Z_1", ()),
        ("2. Alt demographics", alt_demos, alt_demos + controls, "old_dep", ()),
        ("3. 5-year lag", lag_vars, lag_vars + controls, "Z_1_lag5", ()),
        ("4. First differences", diff_vars, diff_vars + controls, "dZ_1", ()),
        ("5. + Investment/GDP", ["gross_investment_gdp"],
         z_vars + controls + ["gross_investment_gdp"], "Z_1", ()),
        ("6. + GDP per capita", ["gdp_pc_ppp"],
         z_vars + controls + ["gdp_pc_ppp"], "Z_1", ()),
        ("7. Excl JPN & DEU", [], z_vars + controls, "Z_1", ("JPN", "DEU")),
        ("8. Predetermined (oadr+20)", ["oadr_plus20"],
         ["oadr_plus20", "Z_2", "Z_3"] + controls, "oadr_plus20", ()),
        ("9. Trilemma interaction", trilemma,
         z_vars + trilemma + controls, "Z_1", ()),
    ]


def _run_matrix(specs, subsamples, dep_var):
    """Fit every spec on every subsample; return ``[(spec_label, [cell, ...]), ...]``."""
    rows = []
    for label, gate_cols, rhs, report_var, exclude in specs:
        print(f"\n  {label}:")
        cells = []
        for sub_label, sub_df in subsamples:
            if exclude:
                sub_df = sub_df[~sub_df["iso3"].isin(list(exclude))]

            if any(c not in sub_df.columns for c in gate_cols):
                cells.append("--")
                print(f"    {sub_label}: SKIP (missing vars)")
                continue

            missing = [v for v in rhs if v not in sub_df.columns]
            if missing:
                cells.append("--")
                print(f"    {sub_label}: SKIP (missing: {missing})")
                continue

            res = run_model(sub_df, dep_var, rhs)
            cell = fmt_cell(res, report_var)
            cells.append(cell)

            if res and report_var in res["coefs"]:
                print(f"    {sub_label}: {report_var}={cell}, R2={res['r2']:.3f}, "
                      f"N={res['nobs']}")
            else:
                print(f"    {sub_label}: {cell}")
        rows.append((label, cells))
    return rows


def _render_table(title, col_labels, rows, controls):
    """Return the markdown text for one robustness table."""
    lines = [
        f"{title}\n",
        "Each cell reports the key demographic coefficient (SE) with significance "
        "stars (* p<0.1, ** p<0.05, *** p<0.01).\n",
        "| Specification | " + " | ".join(col_labels) + " |",
        "|" + "---|" * (len(col_labels) + 1),
    ]
    lines += [f"| {label} | " + " | ".join(cells) + " |" for label, cells in rows]
    lines += ["", "*Notes:*", *_SPEC_NOTES, f"- *Controls: {', '.join(controls)}.*"]
    return "\n".join(lines)


def _robustness_table(number, dep_var, title_suffix, controls, subsamples, filename):
    """Run, print and save one robustness table; return its row data."""
    print("\n" + "=" * 70)
    print(f"TABLE {number}: ROBUSTNESS — {dep_var}")
    print("=" * 70)

    rows = _run_matrix(_specifications(controls), subsamples, dep_var)
    col_labels = [name for name, _ in subsamples]
    md = _render_table(
        f"# Table {number}: Robustness Matrix -- Z_1 on {title_suffix}",
        col_labels, rows, controls,
    )
    print("\n" + md)
    path = TABLE_DIR / filename
    path.write_text(md, encoding="utf-8")
    print(f"\nSaved: {path}")
    return rows


def _print_summary(heading, rows):
    """Print, per specification, how many estimated cells are significant at 5%."""
    print(heading)
    for label, cells in rows:
        sig_count = sum("**" in c for c in cells)  # "***" also contains "**"
        total = sum(c != "--" for c in cells)
        print(f"  {label}: {sig_count}/{total} subsamples significant at 5%")


# Run the robustness pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
